import pickle
import numpy as np
import torch
import os
import shutil

# Work in float64 throughout: keep a short alias for the double tensor class
# and register that same class as the default tensor type.
Tensor = torch.DoubleTensor
torch.set_default_tensor_type(Tensor)

from model import Discriminator, Generator
from helpers import *

# Experiment configuration.
trial_id = 600    # selects the run directory 'saved/<trial_id>/'
device = 'cpu'    # device the models and data are moved to
batch_size = 64   # number of test sequences taken, and the mask's second dimension
T = 49            # first dimension of the sampled mask
mask_freq = 0.2   # Bernoulli probability of a 1 in the sampled mask

save_path = 'saved/%03d/' % trial_id

# Load the hyper-parameters stored with the trained run. The context manager
# closes the file handle deterministically instead of leaving it to the GC.
# NOTE: pickle can execute arbitrary code on load -- only point this at
# files produced by a trusted training run.
with open(save_path + 'params.p', 'rb') as params_file:
    params = pickle.load(params_file)

# Start every run with an empty 'imgs' folder: drop whatever a previous run
# left behind, then (re)create the directory.
if os.path.exists('imgs'):
    shutil.rmtree('imgs')
os.makedirs('imgs', exist_ok=True)

# Build the generator and discriminator from the saved hyper-parameters and
# place them on the configured device.
G = Generator(params['G_params']).to(device)
D = Discriminator(params['D_params']).to(device)

# Restore the final trained weights. map_location remaps the stored tensors
# onto `device`, so a checkpoint written on a GPU still loads on a machine
# that only has the CPU available.
G_state_dict = torch.load(save_path+'model/G_final.pth', map_location=device)
G.load_state_dict(G_state_dict)

D_state_dict = torch.load(save_path+'model/D_final.pth', map_location=device)
D.load_state_dict(D_state_dict)

# Load the evaluation set. The file is opened in a context manager so the
# handle is closed as soon as unpickling finishes.
# NOTE: pickle can execute arbitrary code on load -- only use trusted files.
with open('bball_data/data/basketball_eval.p', 'rb') as eval_file:
    raw_eval = pickle.load(eval_file)

# Swap the first two axes, then drop the last entry along the (new) second
# axis. Presumably this yields (sequence, time, features) -- TODO confirm
# against the layout of the pickled array.
test_data = torch.Tensor(raw_eval).transpose(0, 1)[:, :-1, :]
test_data = test_data.to(device)

# Sample test data
# TODO actually sample instead of taking the first few
data = test_data[:batch_size]
# Swap the first two axes back so the selected sequences sit on axis 1.
data = data.transpose(0, 1)

# Sample mask: one independent Bernoulli(mask_freq) draw per entry of a
# (T, batch_size, 1) tensor.
probs = mask_freq*torch.ones(T, batch_size, 1)
# Sampling happens on the CPU; the result is then moved to `device` so the
# mask lives on the same device as the models and data (a no-op for 'cpu').
mask = torch.distributions.bernoulli.Bernoulli(probs).sample().to(device)

# Run the imputation with the loaded generator/discriminator on the selected
# data and the sampled mask.
# TODO scale lr based on batch_size
imputed_data, mask_data = impute(G, D, data, mask, n_iters=500, lambd=0.0, lr=0.1)

# Indices along the first axis of mask_data whose [idx, 0, 0] entry is zero.
# Presumably these are the time steps hidden for the first sequence --
# TODO confirm against what impute() returns.
missing_list = [
    step for step in range(mask_data.shape[0])
    if mask_data[step, 0, 0] == 0.0
]

# Draw the imputed data and compute statistics, passing along the missing
# indices collected above.
# TODO save plots
mod_stats = draw_and_stats(
    imputed_data,
    'imputed',
    0,
    draw=True,
    compute_stats=True,
    missing_list=missing_list,
)

#import pdb; pdb.set_trace()
